---
title: Custom retriever
description: An overview of retrievers and supported types in the Multi-Agent Orchestrator System
---

The Multi-Agent Orchestrator System allows you to create custom retrievers by extending the abstract Retriever class. This flexibility enables you to integrate various data sources and retrieval methods into your agent system. In this guide, we'll walk through the process of creating a custom retriever, provide an example using OpenSearch Serverless, and explain how to set the retriever for a BedrockLLMAgent.

import { Tabs, TabItem } from '@astrojs/starlight/components';

## Steps to Create a Custom Retriever

<Tabs syncKey="runtime">
  <TabItem label="TypeScript" icon="seti:typescript" color="blue">
    1. Create a new class that extends the `Retriever` abstract class.
    2. Implement the required abstract methods: `retrieve`, `retrieveAndCombineResults`, and `retrieveAndGenerate`.
    3. Add any additional methods or properties specific to your retriever.
  </TabItem>
  <TabItem label="Python" icon="seti:python">
    1. Create a new class that inherits from the `Retriever` abstract base class.
    2. Implement the required abstract methods: `retrieve`, `retrieve_and_combine_results`, and `retrieve_and_generate`.
    3. Add any additional methods or properties specific to your retriever.
  </TabItem>
</Tabs>
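
As a minimal sketch of these steps in Python, the class below implements the three required methods over an in-memory list of passages. To keep the snippet runnable on its own, it declares a stand-in `Retriever` base class; in a real project you would import `Retriever` from `multi_agent_orchestrator.retrievers` instead. `KeywordRetriever` and its word-overlap scoring are illustrative assumptions, not part of the library.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class Retriever(ABC):
    """Stand-in for the library's abstract base class (assumption: same three methods)."""

    def __init__(self, options: Any):
        self._options = options

    @abstractmethod
    async def retrieve(self, text: str) -> Any: ...

    @abstractmethod
    async def retrieve_and_combine_results(self, text: str) -> Any: ...

    @abstractmethod
    async def retrieve_and_generate(self, text: str) -> Any: ...


class KeywordRetriever(Retriever):
    """Hypothetical retriever that ranks passages by words shared with the query."""

    def __init__(self, passages: List[str], k: int = 2):
        super().__init__({"k": k})
        self.passages = passages
        self.k = k

    async def retrieve(self, text: str) -> List[Dict[str, Any]]:
        query_words = set(text.lower().split())
        scored = [
            {"text": p, "score": len(query_words & set(p.lower().split()))}
            for p in self.passages
        ]
        # Keep only passages that share at least one word, best matches first
        matches = [s for s in scored if s["score"] > 0]
        matches.sort(key=lambda s: s["score"], reverse=True)
        return matches[: self.k]

    async def retrieve_and_combine_results(self, text: str) -> str:
        results = await self.retrieve(text)
        return "\n".join(r["text"] for r in results)

    async def retrieve_and_generate(self, text: str) -> str:
        # No generation step in this sketch; return the combined passages
        return await self.retrieve_and_combine_results(text)


retriever = KeywordRetriever([
    "refunds are processed within five days",
    "shipping is free for orders over fifty dollars",
    "contact support to request refunds",
])
combined = asyncio.run(retriever.retrieve_and_combine_results("how are refunds processed"))
print(combined)
```

The same three-method shape applies whatever the data source is; only the body of `retrieve` needs to know how to reach it.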

## Example: OpenSearchServerless Retriever

Here's an example of a custom retriever that uses OpenSearch Serverless:

<Tabs syncKey="runtime">
  <TabItem label="TypeScript" icon="seti:typescript" color="blue">
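    Install the OpenSearch client and the AWS SDK packages used in the example:

    ```bash
    npm install "@opensearch-project/opensearch" "@aws-sdk/credential-provider-node" "@aws-sdk/client-bedrock-runtime"
    ```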

    ```typescript
    import { Retriever } from "multi-agent-orchestrator";
    import { Client } from "@opensearch-project/opensearch";
    import { AwsSigv4Signer } from "@opensearch-project/opensearch/aws";
    import { defaultProvider } from "@aws-sdk/credential-provider-node";
    import { BedrockRuntimeClient, InvokeModelCommand } from "@aws-sdk/client-bedrock-runtime";

    /**
     * Interface for OpenSearchServerlessRetriever options
     */
    export interface OpenSearchServerlessRetrieverOptions {
      collectionEndpoint: string;
      index: string;
      region: string;
      vectorField: string;
      textField: string;
      k: number;
    }

    /**
     * OpenSearchServerlessRetriever class for interacting with OpenSearch Serverless
     * Extends the base Retriever class
     */
    export class OpenSearchServerlessRetriever extends Retriever {
      private client: Client;
      private bedrockClient: BedrockRuntimeClient;

      constructor(options: OpenSearchServerlessRetrieverOptions) {
        super(options);

        if (!options.collectionEndpoint || !options.index || !options.region) {
          throw new Error("collectionEndpoint, index, and region are required in options");
        }

        this.client = new Client({
          ...AwsSigv4Signer({
            region: options.region,
            service: 'aoss',
            getCredentials: () => defaultProvider()(),
          }),
          node: options.collectionEndpoint,
        });

        this.bedrockClient = new BedrockRuntimeClient({ region: options.region });

        this.options.vectorField = options.vectorField;
        this.options.textField = options.textField;
        this.options.k = options.k;
      }

      async retrieve(text: string): Promise<any> {
        try {
          const embeddings = await this.getEmbeddings(text);
          const results = await this.client.search({
            index: this.options.index,
            body: {
              _source: {
                excludes: [this.options.vectorField]
              },
              query: {
                bool: {
                  must: [
                    {
                      knn: {
                        [this.options.vectorField]: { vector: embeddings, k: this.options.k },
                      },
                    },
                  ],
                },
              },
              size: this.options.k,
            },
          });

          return results.body.hits.hits;
        } catch (error) {
          throw new Error(`Failed to retrieve: ${error instanceof Error ? error.message : String(error)}`);
        }
      }

      private async getEmbeddings(text: string): Promise<number[]> {
        try {
          const response = await this.bedrockClient.send(
            new InvokeModelCommand({
              modelId: "amazon.titan-embed-text-v2:0",
              body: JSON.stringify({
                inputText: text,
              }),
              contentType: "application/json",
              accept: "application/json",
            })
          );

          const body = new TextDecoder().decode(response.body);
          const embeddings = JSON.parse(body).embedding;

          if (!Array.isArray(embeddings)) {
            throw new Error("Invalid embedding format received from Bedrock");
          }

          return embeddings;
        } catch (error) {
          throw new Error(`Failed to get embeddings: ${error instanceof Error ? error.message : String(error)}`);
        }
      }

      async retrieveAndCombineResults(text: string): Promise<string> {
        try {
          const results = await this.retrieve(text);
          return results
            .filter((hit: any) => hit._source && hit._source[this.options.textField])
            .map((hit: any) => hit._source[this.options.textField])
            .join("\n");
        } catch (error) {
          throw new Error(`Failed to retrieve and combine results: ${error instanceof Error ? error.message : String(error)}`);
        }
      }

      async retrieveAndGenerate(text: string): Promise<string> {
        return this.retrieveAndCombineResults(text);
      }

      async updateDocument(id: string, content: any): Promise<any> {
        try {
          const response = await this.client.update({
            index: this.options.index,
            id: id,
            body: {
              doc: content
            }
          });
          return response.body;
        } catch (error) {
          throw new Error(`Failed to update document: ${error instanceof Error ? error.message : String(error)}`);
        }
      }
    }
    ```
  </TabItem>
  <TabItem label="Python" icon="seti:python">
    Install required Python packages:

    ```bash
    pip install opensearch-py boto3 requests-aws4auth
    ```

    ```python
    from typing import Any, Dict, List
    from multi_agent_orchestrator.retrievers import Retriever
    from opensearchpy import OpenSearch, RequestsHttpConnection
    from requests_aws4auth import AWS4Auth
    import boto3
    import json

    class OpenSearchServerlessRetrieverOptions:
        def __init__(self, collection_endpoint: str, index: str, region: str, vector_field: str, text_field: str, k: int):
            self.collection_endpoint = collection_endpoint
            self.index = index
            self.region = region
            self.vector_field = vector_field
            self.text_field = text_field
            self.k = k

    class OpenSearchServerlessRetriever(Retriever):
        def __init__(self, options: OpenSearchServerlessRetrieverOptions):
            super().__init__(options)
            self.options = options

            if not all([options.collection_endpoint, options.index, options.region]):
                raise ValueError("collection_endpoint, index, and region are required in options")

            credentials = boto3.Session().get_credentials()
            awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, options.region, 'aoss',
                               session_token=credentials.token)

            self.client = OpenSearch(
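                # 'host' must be a bare hostname, so strip the URL scheme if the endpoint includes one
                hosts=[{'host': options.collection_endpoint.replace('https://', '', 1), 'port': 443}],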
                http_auth=awsauth,
                use_ssl=True,
                verify_certs=True,
                connection_class=RequestsHttpConnection
            )

            self.bedrock_client = boto3.client('bedrock-runtime', region_name=options.region)

        async def retrieve(self, text: str) -> List[Dict[str, Any]]:
            try:
                embeddings = await self.get_embeddings(text)
                query = {
                    "_source": {
                        "excludes": [self.options.vector_field]
                    },
                    "query": {
                        "knn": {
                            self.options.vector_field: {
                                "vector": embeddings,
                                "k": self.options.k
                            }
                        }
                    },
                    "size": self.options.k
                }
                response = self.client.search(index=self.options.index, body=query)
                return response['hits']['hits']
            except Exception as e:
                raise Exception(f"Failed to retrieve: {str(e)}")

        async def get_embeddings(self, text: str) -> List[float]:
            try:
                response = self.bedrock_client.invoke_model(
                    modelId="amazon.titan-embed-text-v2:0",
                    body=json.dumps({"inputText": text}),
                    contentType="application/json",
                    accept="application/json"
                )
                embeddings = json.loads(response['body'].read())['embedding']
                if not isinstance(embeddings, list):
                    raise ValueError("Invalid embedding format received from Bedrock")
                return embeddings
            except Exception as e:
                raise Exception(f"Failed to get embeddings: {str(e)}")

        async def retrieve_and_combine_results(self, text: str) -> str:
            try:
                results = await self.retrieve(text)
                return "\n".join(
                    hit['_source'][self.options.text_field]
                    for hit in results
                    if self.options.text_field in hit['_source']
                )
            except Exception as e:
                raise Exception(f"Failed to retrieve and combine results: {str(e)}")

        async def retrieve_and_generate(self, text: str) -> str:
            return await self.retrieve_and_combine_results(text)

        async def update_document(self, id: str, content: Dict[str, Any]) -> Dict[str, Any]:
            try:
                response = self.client.update(
                    index=self.options.index,
                    id=id,
                    body={"doc": content}
                )
                return response
            except Exception as e:
                raise Exception(f"Failed to update document: {str(e)}")
    ```
  </TabItem>
</Tabs>
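
The query-building and result-combining logic in the example can be exercised without an OpenSearch collection. The sketch below pulls both into plain functions and feeds them hand-written hits shaped like an OpenSearch search response. `build_knn_query` and `combine_hits` are hypothetical helper names, not part of the library, and the sample hits are made up for illustration.

```python
from typing import Any, Dict, List


def build_knn_query(vector_field: str, embedding: List[float], k: int) -> Dict[str, Any]:
    """Build the same k-NN search body used in the retriever's retrieve method."""
    return {
        "_source": {"excludes": [vector_field]},  # don't return the raw vectors
        "query": {"knn": {vector_field: {"vector": embedding, "k": k}}},
        "size": k,
    }


def combine_hits(hits: List[Dict[str, Any]], text_field: str) -> str:
    """Join the text field of every hit that has one, one passage per line."""
    return "\n".join(
        hit["_source"][text_field]
        for hit in hits
        if text_field in hit.get("_source", {})
    )


# Hand-written hits that mimic response['hits']['hits'] (illustrative data only)
fake_hits = [
    {"_id": "1", "_source": {"textField": "Paris is the capital of France."}},
    {"_id": "2", "_source": {"otherField": "no text here"}},
    {"_id": "3", "_source": {"textField": "France is in Western Europe."}},
]

query = build_knn_query("vectorField", [0.1, 0.2, 0.3], k=2)
context = combine_hits(fake_hits, "textField")
print(context)
```

Hits without the configured text field are skipped rather than raising, which matches the filtering done in `retrieve_and_combine_results` above.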

## Using the Custom Retriever with BedrockLLMAgent

To use your custom OpenSearchServerlessRetriever:

<Tabs syncKey="runtime">
  <TabItem label="TypeScript" icon="seti:typescript" color="blue">
    ```typescript
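    import { BedrockLLMAgent } from "multi-agent-orchestrator";
    // Adjust this path to wherever you saved the retriever class defined above
    import { OpenSearchServerlessRetriever } from "./openSearchServerlessRetriever";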

    const agent = new BedrockLLMAgent({
      name: 'My Bedrock Agent with OpenSearch Serverless',
      description: 'An agent that uses OpenSearch Serverless for retrieval',
      retriever: new OpenSearchServerlessRetriever({
          collectionEndpoint: "https://xxxxxxxxxxx.us-east-1.aoss.amazonaws.com",
          index: "vector-index",
          region: process.env.AWS_REGION!,
          textField: "textField",
          vectorField: "vectorField",
          k: 5,
        })
    });

    // Example usage
    const query = "What is the capital of France?";
    const response = await agent.processRequest(query, 'user123', 'session456', []);
    console.log(response);
    ```
  </TabItem>
  <TabItem label="Python" icon="seti:python">
    ```python
    from multi_agent_orchestrator.agents import BedrockLLMAgent, BedrockLLMAgentOptions
    from custom_retriever import OpenSearchServerlessRetriever, OpenSearchServerlessRetrieverOptions
    import os

    agent = BedrockLLMAgent(BedrockLLMAgentOptions(
        name='My Bedrock Agent with OpenSearch Serverless',
        description='An agent that uses OpenSearch Serverless for retrieval',
        retriever=OpenSearchServerlessRetriever(OpenSearchServerlessRetrieverOptions(
            collection_endpoint="https://xxxxxxxxxxx.us-east-1.aoss.amazonaws.com",
            index="vector-index",
            region=os.environ.get('AWS_REGION'),
            text_field="textField",
            vector_field="vectorField",
            k=5
        ))
    ))

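    # Example usage (process_request is a coroutine, so run it inside an event loop)
    import asyncio

    async def main():
        query = "What is the capital of France?"
        response = await agent.process_request(query, 'user123', 'session456', [])
        print(response)

    asyncio.run(main())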
    ```
  </TabItem>
</Tabs>

In this example, we create an instance of the custom `OpenSearchServerlessRetriever` and pass it to the `BedrockLLMAgent` through the `retriever` field (in Python, via `BedrockLLMAgentOptions`). The agent then uses the retriever to fetch relevant knowledge while processing each request.

## How BedrockLLMAgent Uses the Retriever

When a BedrockLLMAgent processes a request and a retriever is set, it typically follows these steps:

1. The agent receives a user query through the `processRequest` method (`process_request` in Python).
2. Before sending the query to the language model, the agent calls the retriever's `retrieveAndCombineResults` method (`retrieve_and_combine_results` in Python) with the user's query.
3. The retriever fetches relevant information from its data source (in this case, OpenSearch Serverless).
4. The retrieved information is combined and added to the context sent to the language model.
5. The language model then generates a response based on both the user's query and the additional context provided by the retriever.

This process allows the agent to leverage external knowledge sources, potentially improving the accuracy and relevance of its responses.
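
The steps above can be sketched roughly as follows. This is not the library's actual `BedrockLLMAgent` code: `StaticRetriever`, `fake_model`, and the prompt layout are hypothetical stand-ins used only to show where the retrieved context enters the request.

```python
import asyncio
from typing import Awaitable, Callable


class StaticRetriever:
    """Hypothetical retriever that always returns the same passage."""

    async def retrieve_and_combine_results(self, text: str) -> str:
        return "Paris is the capital of France."


async def fake_model(system_prompt: str, user_input: str) -> str:
    """Stand-in for the language model call; echoes what it was given."""
    return f"[system: {system_prompt}] [user: {user_input}]"


async def process_request(
    user_input: str,
    retriever,
    model: Callable[[str, str], Awaitable[str]],
    base_prompt: str = "You are a helpful assistant.",
) -> str:
    system_prompt = base_prompt
    if retriever is not None:
        # Steps 2-3: ask the retriever for context relevant to the query
        context = await retriever.retrieve_and_combine_results(user_input)
        # Step 4: add the retrieved context to what the model sees
        system_prompt += f"\nHere is the context to use to answer the user's question:\n{context}"
    # Step 5: the model answers using both the query and the context
    return await model(system_prompt, user_input)


answer = asyncio.run(
    process_request("What is the capital of France?", StaticRetriever(), fake_model)
)
print(answer)
```

When no retriever is set, the sketch skips the retrieval steps and the model sees only the base prompt and the user's query.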

---

By adapting this example, you can plug your existing **OpenSearch Serverless collections** into the Multi-Agent Orchestrator System and extend your agents' knowledge retrieval capabilities.